from xpuyu.modelings.internlm_moe import InternLM3MoEConfig, InternLM3MoEForCausalLM
import torch
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer

from xpuyu.modelings.qwen_moe import Qwen2MoeConfig, Qwen2MoeForCausalLM
from xtuner._lite.accelerate import LoadWoInit

# with LoadWoInit():
#     llm = AutoModelForCausalLM.from_pretrained(
#         '/fs-computility/llm/shared/caoweihan/models/qwen_moe_lite',
#         trust_remote_code=True,
#         torch_dtype=torch.bfloat16,
#         attn_implementation='flash_attention_2')

def _routed_expert_count(cfg):
    """Return the number of routed experts declared on ``cfg``.

    Config classes spell this field differently, so the attribute names are
    probed in priority order; the first one present wins.

    Raises:
        AttributeError: If ``cfg`` exposes none of the known attribute names.
    """
    for attr in ('num_routed_experts', 'n_routed_experts', 'num_experts'):
        if hasattr(cfg, attr):
            return getattr(cfg, attr)
    raise AttributeError(
        'config has none of num_routed_experts / n_routed_experts / '
        'num_experts; cannot compute activated expert parameters')


def cal(llm, cfg):
    """Print a parameter-count breakdown (in units of 1e9) for an MoE model.

    Every ``(name, param)`` pair from ``llm.named_parameters()`` is classified
    purely by substring matches on ``name``:

    * ``'expert'`` in the name -> counted as MoE, otherwise as "Other".
    * ``'.experts.'`` in the name -> treated as a routed-expert weight; only
      the fraction ``num_experts_per_tok / <routed expert count>`` of it is
      added to the activated total. Every other parameter is added in full.
    * ``'attention'`` or ``'self_attn'`` in the name -> counted as attention.

    Args:
        llm: Object whose ``named_parameters()`` yields ``(name, param)``
            pairs, where ``param.numel()`` returns the element count.
        cfg: Config providing ``num_experts_per_tok`` and the routed-expert
            count under ``num_routed_experts``, ``n_routed_experts`` or
            ``num_experts`` (checked in that order). It is only read once a
            ``'.experts.'`` parameter is encountered.

    Returns:
        None. The name of each ``'.experts.'`` parameter is printed as it is
        visited, followed by one summary line.

    Raises:
        AttributeError: If a ``'.experts.'`` parameter is found and ``cfg``
            lacks ``num_experts_per_tok`` or every routed-expert-count
            attribute.
    """
    numel_act = 0
    numel_total = 0
    numel_moe = 0
    numel_wo_moe = 0
    numel_attn = 0
    # Resolved on first use so that models without routed experts never
    # touch the MoE-specific config fields.
    routed_experts = None

    for name, param in llm.named_parameters():
        count = param.numel()
        numel_total += count

        if 'expert' in name:
            numel_moe += count
        else:
            numel_wo_moe += count

        if '.experts.' in name:
            print(name)
            if routed_experts is None:
                routed_experts = _routed_expert_count(cfg)
            # Only num_experts_per_tok of the routed experts run per token.
            numel_act += count * cfg.num_experts_per_tok / routed_experts
        else:
            numel_act += count

        if 'attention' in name or 'self_attn' in name:
            numel_attn += count

    # "MoE act param" is the activated total minus everything that is not
    # MoE, i.e. the activated share of the parameters matching 'expert'.
    print(
        f'Total act param: {numel_act / 1e9}, Total param: {numel_total / 1e9}, MoE param: {numel_moe / 1e9}, '
        f'Other param: {numel_wo_moe / 1e9}, Attn param: {numel_attn / 1e9}, MoE act param: {(numel_act - numel_wo_moe) / 1e9}'
    )

# model = Qwen2MoeForCausalLM.from_pretrained('/fs-computility/llm/shared/caoweihan/models/qwen_moe_lite', 
#     torch_dtype=torch.bfloat16, attn_implementation='flash_attention_2', )

# cal(model, model.config)

# breakpoint()


# config = Qwen2MoeConfig(
#     vocab_size=128133,
#     hidden_size=2048,
#     num_hidden_layers=32,
#     num_attention_heads=16,
#     num_key_value_heads=2,
#     moe_intermediate_size=1536,
#     num_experts_per_tok=8,
#     n_routed_experts=64,
#     # num_shared_experts=0,
#     output_router_logits=True,
#     shared_expert_intermediate_size=0
# )

# model = Qwen2MoeForCausalLM._from_config(config, attn_implementation='flash_attention_2')
# breakpoint()
# model.cuda()
# model.to(torch.bfloat16)
# model.config.use_cache = False

# ids = torch.randint(0, 1000, (1, 1024, )).cuda()
# out = model(input_ids=ids)

# # model = AutoModelForCausalLM.from_pretrained('internlm3_moe_lite/', attn_implementation='flash_attention_2', trust_remote_code=True)

# model.save_pretrained('/fs-computility/llm/shared/caoweihan/models/qwen_moe_lite')
# # tok = AutoTokenizer.from_pretrained('/cpfs01/shared/public/caoweihan/tokenizer_1105_120k', use_fast=False, trust_remote_code=True)
# # tok.save_pretrained('internlm3_moe_lite')
# breakpoint()


# Architecture hyper-parameters for the InternLM3 MoE model built below.
config = InternLM3MoEConfig(
    vocab_size=128512,
    hidden_size=2048,
    num_hidden_layers=32,
    num_attention_heads=32,
    num_key_value_heads=4,
    intermediate_size=1536,
    num_experts_per_tok=8,
    num_experts=64,
    num_shared_experts=0,
    bias=False,
    router_z_loss_coef=0,
    balancing_loss_coef=0.01,
    rope_theta=100000,
    max_position_embeddings=32768
)

# One output directory for both the weights and the tokenizer, so the two
# save calls below cannot drift apart.
save_dir = 'internlm3_moe_lite_new2_128512_fix_init'
# Existing directory the tokenizer files are copied from.
tokenizer_src = 'internlm3_moe_lite_new'

# Build the model from the config alone; no checkpoint is loaded here.
model = InternLM3MoEForCausalLM._from_config(config, attn_implementation='flash_attention_2')
# Cast to bfloat16 *before* moving to the GPU: if the model was created in a
# wider dtype (PyTorch's default is float32) the device then never has to
# hold the wider copy. The final dtype/device is the same as cuda-then-cast.
model.to(torch.bfloat16)
model.cuda()
model.config.use_cache = False

# Smoke-test forward pass on random token ids (batch 1, 1024 tokens).
# Gradients are never used, so skip building the autograd graph.
ids = torch.randint(0, 1000, (1, 1024, )).cuda()
with torch.no_grad():
    out = model(input_ids=ids)

# model = AutoModelForCausalLM.from_pretrained('internlm3_moe_lite/', attn_implementation='flash_attention_2', trust_remote_code=True)

model.save_pretrained(save_dir)
tok = AutoTokenizer.from_pretrained(tokenizer_src, use_fast=False, trust_remote_code=True)
tok.save_pretrained(save_dir)
# Drop into the debugger so `model`, `out` and `tok` can be inspected
# interactively; set PYTHONBREAKPOINT=0 to skip this.
breakpoint()